# This script works with run_simulation.slurm to execute a single simulation from a slurm array.
## Usage: Rscript run_simulation.R
## Reads these environment variables: N_SIMS_PER_DESIGN, DEV, SLURM_ARRAY_TASK_ID, SIM_ID

# Runtime configuration, read from environment variables
N_SIMS_PER_DESIGN <- as.numeric(Sys.getenv("N_SIMS_PER_DESIGN"))
DEV <- as.logical(Sys.getenv("DEV"))
SLURM_ARRAY_ID <- as.numeric(Sys.getenv("SLURM_ARRAY_TASK_ID"))
SIM_ID <- as.numeric(Sys.getenv("SIM_ID"))

# Switch to the project directory, then load the simulation helpers
setwd("/projects/BSTEWART/dsl-image/dsl-image-simulation")
source("src/simulation_utils.R")

# Run a single design cell of one simulation and save its result as an .rds file.
#
# Args:
#   simulation: simulation number; 1 and 2 are implemented, any other value
#     stops with "Invalid simulation number".
#   slurm_array_id: 1-based row index into the simulation's design grid
#     (36 rows for simulation 1, 108 rows for simulation 2).
#
# Returns (invisibly) NULL without recomputing when the output file
# results/sim<simulation>/<slurm_array_id, zero-padded to 3 digits>.rds already
# exists; otherwise writes that file with saveRDS().
#
# Relies on names defined outside this function: DEV, N_SIMS_PER_DESIGN,
# load_data(), repeat_simulation() and random_binary_error() (the functions are
# presumably provided by src/simulation_utils.R -- confirm there).
main <- function(simulation, slurm_array_id=1){
  # Define out file
  out_file <- paste0("results/sim", simulation, "/",  sprintf("%03d", slurm_array_id), ".rds")
  ## Check if exists; if so then exit:
  if (file.exists(out_file)) {
    print(paste("File exists:", out_file))
    return(invisible(NULL))
  }

  ## Common to all simulations
  src_data <- load_data(dev=DEV)
  fml <- formula("violence ~ sign + photo + fire + police + children + group_20 + flag + night + shouting")
  surrogate_var <- "fire"
  estimator_name <- "lm"

  ## "Malign" surrogate error process shared by both simulations: the chance
  ## that a row is flipped depends on other covariates. `design` and
  ## `surrogate_var` are looked up lazily in main()'s frame, so `design` only
  ## has to exist by the time the returned function is called.
  make_malign_q_error_func <- function(coef_mean, coef_sd){
    force(coef_mean)
    force(coef_sd)
    function(sim_data, acc=design$q_acc){
      # Create error propensity as linear combination of covariates
      error_ps_covs <- c("sign", "police", "flag", "night", "shouting")
      error_ps_coefs <- rnorm(length(error_ps_covs), mean=coef_mean, sd=coef_sd)
      error_ps <- as.matrix(sim_data[,error_ps_covs]) %*% error_ps_coefs
      # Shift so every row keeps a strictly positive weight: without the small
      # offset, rows at the minimum get probability 0 and sampling without
      # replacement can fail with "too few positive probabilities" (or NaN
      # weights when all propensities are equal).
      error_ps_shifted <- error_ps - min(error_ps) + 1e-5
      error_ps_norm <- error_ps_shifted / sum(error_ps_shifted) # Normalize to add up to 1

      # Sample errors
      error_size <- round(nrow(sim_data) * (1 - acc))
      error_idxs <- sample(seq_len(nrow(sim_data)), size=error_size, replace=FALSE, prob=error_ps_norm)

      # Create surrogate with errors (flip the binary label on the sampled rows)
      Q <- sim_data[,surrogate_var]
      Q[error_idxs] <- 1 - Q[error_idxs]
      Q
    }
  }

  if (simulation==1) {
    # Plot 1: Showing we need DSL.
    # For each of the three metrics (bias, RMSE, coverage).  I want a version of
    # the last column (gold standard accuracy = 100%) tracing along the x-axis
    # the accuracy of the surrogate with lines for different sample sizes. You
    # have run this already for what I'm going to call the "benign" process.  I
    # want a version now that has a "malign" process.  That is the errors are
    # correlated with the variable of interest in some way.  The key is that our
    # other papers use a quite adversarial process.  We don't need to go that
    # adversarial in presenting the result but I think we also need to show both
    # the most benign case and something that is more realistically adversarial.
    # I'll leave it to you how to pick that process (although I would err
    # towards something simple like different probabilities of flipping based on
    # variables of interest).

    ## Simulation 1:
    # - gs_acc = 1.0
    # - q error function of other variables - modified simulation function to take sim_data as argument

    ## Design grid
    design_grid <- expand.grid(
        n_label = c(100, 250, 500, 750, 1000, 2000),
        gs_acc = c(1.0),
        q_acc = c(0.5, 0.75, 0.9, 0.95, 0.99, 1.0)
    ) # 36 designs

    ## Define error functions
    # Gold standard is error-free in this simulation
    gs_error_func <- function(sim_data){return(sim_data[,surrogate_var])}
    q_error_func <- make_malign_q_error_func(coef_mean=1, coef_sd=10.0)

  } else if (simulation==2) {
    # Plot 2: Showing that GS errors aren't a huge, huge deal.
    # Again I want a variant for each of three metrics.  For gold standard
    # errors, we use the benign process.  For surrogate accuracy we use whatever
    # new malign process you come up with.  Each plot should use one surrogate
    # accuracy level, but I'd love to see options for 75%, 80% and 85% accurate
    # (which seem reasonable to me as surrogate accuracies) and then I'll choose
    # one.  Again each metric-surrogate_accuracy combination should ideally be
    # one plot.  I'll let you figure out how to do that, but I think the key is
    # probably that gold standard accuracy is on the x-axis.  When plotting I'd
    # run it backwards though (so 100% accurate gold standard is on the far left
    # and human error increases as you move right).

    ## Simulation 2:
    # - gs_acc benign process
    # - q_acc malign process: 75, 80, 85

    ## Design grid
    design_grid <- expand.grid(
        n_label = c(100, 250, 500, 750, 1000, 2000),
        gs_acc = c(0.5, 0.75, 0.9, 0.95, 0.99, 1.0),
        q_acc = c(0.75, 0.8, 0.85)
    ) # 108 designs

    ## Define error functions
    gs_error_func <- function(sim_data, acc=design$gs_acc){
      random_binary_error(sim_data[,surrogate_var], acc)
    }
    q_error_func <- make_malign_q_error_func(coef_mean=0, coef_sd=1.0)

  } else {
    stop("Invalid simulation number")
  }

  ## Select this task's design. Done once for every simulation (previously it
  ## was only assigned for simulation 1, so simulation 2 failed with
  ## "object 'design' not found"). An out-of-range index would otherwise
  ## yield a silent all-NA design row.
  if (length(slurm_array_id) != 1 || is.na(slurm_array_id) ||
      slurm_array_id < 1 || slurm_array_id > nrow(design_grid)) {
    stop("slurm_array_id must be between 1 and ", nrow(design_grid),
         " for simulation ", simulation)
  }
  design <- design_grid[slurm_array_id,]
  print(design)

  ## Run simulation
  sim_result <- repeat_simulation(
      src_data = src_data,
      fml = fml,
      surrogate_var = surrogate_var,
      n_label = design$n_label,
      estimator_name = estimator_name,
      q_error_func = q_error_func,
      gs_error_func = gs_error_func,
      n_sims = N_SIMS_PER_DESIGN
  )

  ## Save results (create results/sim<simulation>/ on first use)
  dir.create(dirname(out_file), recursive=TRUE, showWarnings=FALSE)
  saveRDS(sim_result, file=out_file)
}

# Entry point: run the design selected by the SIM_ID / SLURM array environment variables
main(SIM_ID, slurm_array_id = SLURM_ARRAY_ID)
